{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# QAT with convolution and batchnorm fused constraints"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "本教程以Conv2D和BatchNorm的融合为例，介绍如何使用PaddleSlim接口快速为量化训练添加训练约束(constraints)。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. 添加依赖"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import paddle\n",
    "from paddle.vision.models import resnet18\n",
    "from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver\n",
    "from paddleslim.quant import SlimQuantConfig as QuantConfig\n",
    "from paddleslim.quant import SlimQAT\n",
    "from paddleslim.quant.constraints import FreezedConvBNConstraint\n",
    "paddle.set_device(\"cpu\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. 构造模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = resnet18()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. 量化训练\n",
    "\n",
    "**配置**\n",
    "\n",
    "构造量化配置实例，并将激活和权重的量化方式指定为基础的 AbsMax 量化策略。然后，调用 `add_constraints` 配置一个Constraints实例。\n",
    "最后，使用量化配置构造SlimQAT对象。代码如下所示："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.9)\n",
    "q_config = QuantConfig(activation=quanter, weight=quanter)\n",
    "q_config.add_constraints(FreezedConvBNConstraint())\n",
    "qat = SlimQAT(q_config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**量化**\n",
    "\n",
    "调用SlimQAT对象的quantize方法，将模型转为用于量化训练的模型。\n",
    "用户需要指定一个inputs, 用于推断分析动态图的执行拓扑图，以便自动检测出所有Conv2D和BN的组合。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = paddle.rand([1, 3, 224, 224])\n",
    "quant_model = qat.quantize(model, inplace=True, inputs=x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在该步骤，所有的Conv2D和BN的组合，都会被替换为QuantedConv2DBatchNorm layer。在QuantedConv2DBatchNorm中，参考 [Quantizing deep convolutional networks for efficient inference: A whitepaper](https://arxiv.org/abs/1806.08342) ，在量化训练过程中模拟Conv2D和BatchNorm的融合。原理如下图所示：\n",
    "\n",
    "<div align=\"center\"><img src=\"https://user-images.githubusercontent.com/7534971/221735696-f78fdaff-2067-4a76-bb92-c99ae4740f2f.png\" width=\"500\"></div>\n",
    "<div align=\"center\">Conv BN 量化训练矫正方案示意图</div>\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**训练**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "quant_model.train()\n",
    "out = quant_model(x)\n",
    "out.backward()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**查看量化训练模型结构**\n",
    "\n",
    "直接在终端输出量化训练模型的结构，或者将模型保存到文件系统，并使用[netron](https://netron.app/)查看模型结构。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(quant_model)\n",
    "quant_model.eval()\n",
    "paddle.jit.save(quant_model, \"./qat_model\", input_spec=[x])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "模拟量化训练模型结构如下图所示：\n",
    "\n",
    "<div align=\"center\"><img width=\"300\" alt=\"image\" src=\"https://user-images.githubusercontent.com/7534971/222439015-375211b2-4004-4452-84a8-79d3d94ccf42.png\"></div>\n",
    "<div align=\"center\">模拟量化模型结构示意图</div>\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**保存推理模型**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "infer_model = qat.convert(quant_model, inplace=True)\n",
    "print(infer_model)\n",
    "paddle.jit.save(infer_model, \"./infer_model\", input_spec=[x])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "推理模型结构如下图所示：\n",
    "\n",
    "<div align=\"center\"><img width=\"300\" alt=\"image\" src=\"https://user-images.githubusercontent.com/7534971/222439260-33831915-10fe-41d8-a822-01537dcbce67.png\"></div>\n",
    "<div align=\"center\">量化推理模型结构示意图</div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  },
  "vscode": {
   "interpreter": {
    "hash": "2f394aca7ca06fed1e6064aef884364492d7cdda3614a461e02e6407fc40ba69"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
